///////////////////////////////////////////////////////////////////////////////
// vim:et:ts=2:sw=2:ci:cino=f0,g0,t0,+0:
//                                                                           //
// The Template Matrix/Vector Library for C++ was created by Mike Jarvis     //
// Copyright (C) 1998 - 2009                                                 //
//                                                                           //
// The project is hosted at http://sourceforge.net/projects/tmv-cpp/         //
// where you can find the current version and current documentation.         //
//                                                                           //
// For concerns or problems with the software, Mike may be contacted at      //
// mike_jarvis@users.sourceforge.net                                         //
//                                                                           //
// This program is free software; you can redistribute it and/or             //
// modify it under the terms of the GNU General Public License               //
// as published by the Free Software Foundation; either version 2            //
// of the License, or (at your option) any later version.                    //
//                                                                           //
// This program is distributed in the hope that it will be useful,           //
// but WITHOUT ANY WARRANTY; without even the implied warranty of            //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the             //
// GNU General Public License for more details.                              //
//                                                                           //
// You should have received a copy of the GNU General Public License         //
// along with this program in the file LICENSE.                              //
//                                                                           //
// If not, write to:                                                         //
// The Free Software Foundation, Inc.                                        //
// 51 Franklin Street, Fifth Floor,                                          //
// Boston, MA  02110-1301, USA.                                              //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////


//#define XDEBUG


#include "tmv/TMV_BandMatrixArithFunc.h"
#include "tmv/TMV_BandMatrix.h"
#include "tmv/TMV_VectorArith.h"
#include "tmv/TMV_BandMatrixArith.h"

#ifdef XDEBUG
#include "tmv/TMV_MatrixArith.h"
#include <iostream>
using std::cout;
using std::cerr;
using std::endl;
#endif

namespace tmv {

  //
  // AddMM
  //

  // Kernel for B += alpha * A, used by the AddMM overloads below.
  //
  // Preconditions (enforced with TMVAssert):
  //   - A and B have the same shape and B is nonempty;
  //   - B's band is at least as wide as A's on both sides;
  //   - alpha is nonzero;
  //   - B is not conjugated and A and B do not share storage.
  //
  // If A and B have the same storage order and band widths and both can be
  // linearized, the update is done through their linear views in one vector
  // operation.  Otherwise A is added to B one diagonal at a time; diagonals
  // of B outside A's band are left untouched.
  template <class T, class Ta> static void DoAddMM(
      const T alpha, const GenBandMatrix<Ta>& A, const BandMatrixView<T>& B)
  {
    TMVAssert(A.colsize() == B.colsize());
    TMVAssert(A.rowsize() == B.rowsize());
    TMVAssert(B.nlo() >= A.nlo());
    TMVAssert(B.nhi() >= A.nhi());
    TMVAssert(alpha != T(0));
    TMVAssert(B.colsize() > 0);
    TMVAssert(B.rowsize() > 0);
    TMVAssert(B.ct() == NonConj);
    TMVAssert(!SameStorage(A,B));

    const bool useLinearViews =
        A.stor() == B.stor() && A.nlo() == B.nlo() && A.nhi() == B.nhi() &&
        A.CanLinearize() && B.CanLinearize();

    if (!useLinearViews) {
      const int firstDiag = -A.nlo();
      const int lastDiag = A.nhi();
      for (int k = firstDiag; k <= lastDiag; ++k)
        B.diag(k) += alpha * A.diag(k);
      return;
    }

    // Identical layouts, so the linear views correspond element by element.
    TMVAssert(A.stepi() == B.stepi() && A.stepj() == B.stepj());
    B.LinearView() += alpha*A.ConstLinearView();
  }

  template <class T, class Ta> void AddMM(const T alpha,
      const GenBandMatrix<Ta>& A, const BandMatrixView<T>& B)
  // B += alpha * A
  //
  // A must have the same shape as B and no more sub/super-diagonals than B.
  // A zero alpha or an empty B is a no-op; returning early also keeps
  // DoAddMM's alpha != 0 precondition satisfied.
  // If B is conjugated, the equivalent update conj(B) += conj(alpha)*conj(A)
  // is done instead, so that DoAddMM always sees a NonConj target.
  // If A and B share storage, A is first copied into a temporary whose
  // storage order matches B's, so DoAddMM's no-aliasing precondition holds.
  {
#ifdef XDEBUG
    //cout<<"Band AddMM: alpha = "<<alpha<<endl;
    //cout<<"A = "<<TypeText(A)<<"  "<<A<<endl;
    //cout<<"B = "<<TypeText(B)<<"  "<<B<<endl;
    Matrix<Ta> A0 = A;
    Matrix<T> B0 = B;
    Matrix<T> B2 = B0 + alpha*A0;
#endif
    TMVAssert(A.colsize() == B.colsize());
    TMVAssert(A.rowsize() == B.rowsize());
    TMVAssert(B.nlo() >= A.nlo());
    TMVAssert(B.nhi() >= A.nhi());

    if (alpha != T(0) && B.colsize() > 0 && B.rowsize() > 0) {
      if (B.isconj()) {
        AddMM(CONJ(alpha),A.Conjugate(),B.Conjugate());
      } else if (SameStorage(A,B)) {
        // Copy A so the kernel never reads from memory it is writing.
        if (B.isrm()) {
          BandMatrix<Ta,RowMajor> A2 = A;
          DoAddMM(alpha,A2,B);
        } else if (B.iscm()) {
          BandMatrix<Ta,ColMajor> A2 = A;
          DoAddMM(alpha,A2,B);
        } else {
          BandMatrix<Ta,DiagMajor> A2 = A;
          DoAddMM(alpha,A2,B);
        }
      } else {
        DoAddMM(alpha,A,B);
      }
    }
#ifdef XDEBUG
    // Additive tolerance (as in the three-argument overloads): a product of
    // the norms would collapse to zero whenever B started out as zero.
    if (Norm(B2-Matrix<T>(B)) > 0.001*(ABS(alpha)*Norm(A0)+Norm(B0))) {
      cerr<<"Band AddMM\n";
      cerr<<"alpha = "<<alpha<<endl;
      cerr<<"A = "<<TypeText(A)<<"  "<<A.cptr()<<"   "<<A<<endl;
      cerr<<"B = "<<TypeText(B)<<"  "<<B.cptr()<<"   "<<B0<<endl;
      cerr<<"B => "<<B<<endl;
      cerr<<"B2 = "<<B2<<endl;
      abort();
    }
#endif
  }

  template <class T, class Ta, class Tb> void AddMM(const T alpha,
      const GenBandMatrix<Ta>& A, const T beta, const GenBandMatrix<Tb>& B,
      const BandMatrixView<T>& C)
  // C = alpha * A + beta * B
  //
  // A, B and C must have the same shape, and C's band must be at least as
  // wide as both A's and B's.  If C is conjugated, the equivalent update on
  // the conjugated views is done (both scalars conjugated).
  // A zero alpha or beta reduces to a single scaled copy.  This avoids
  // handing a zero scalar to DoAddMM, which asserts alpha != 0.
  // Otherwise C is overwritten with one scaled input, and the other is
  // added with DoAddMM.  The order is chosen so that an input aliasing C is
  // read before C is overwritten.  If both inputs alias C, B is first copied
  // to a temporary laid out like C.
  {
#ifdef XDEBUG
    //cout<<"Band AddMM: alpha = "<<alpha<<", beta = "<<beta<<endl;
    //cout<<"A = "<<TypeText(A)<<"  "<<A<<endl;
    //cout<<"B = "<<TypeText(B)<<"  "<<B<<endl;
    //cout<<"C = "<<TypeText(C)<<endl;
    Matrix<Ta> A0 = A;
    Matrix<Tb> B0 = B;
    Matrix<T> C2 = alpha*A0+beta*B0;
#endif
    TMVAssert(A.colsize() == B.colsize());
    TMVAssert(A.rowsize() == B.rowsize());
    TMVAssert(A.colsize() == C.colsize());
    TMVAssert(A.rowsize() == C.rowsize());
    TMVAssert(C.nlo() >= A.nlo());
    TMVAssert(C.nhi() >= A.nhi());
    TMVAssert(C.nlo() >= B.nlo());
    TMVAssert(C.nhi() >= B.nhi());

    if (C.isconj()) {
      AddMM(CONJ(alpha),A.Conjugate(),CONJ(beta),B.Conjugate(),C.Conjugate());
    }
    else if (B.colsize() > 0 && B.rowsize() > 0) {
      if (alpha == T(0)) {
        // Nothing to add from A (this also covers alpha == beta == 0).
        C = beta*B;
      } else if (beta == T(0)) {
        C = alpha*A;
      } else if (SameStorage(A,C)) {
        if (SameStorage(B,C)) {
          // Both inputs alias C: save B in a temporary whose storage order
          // matches the destination C, then overwrite C with alpha*A.
          if (C.isrm()) {
            BandMatrix<Tb,RowMajor> tempB = B;
            C = alpha*A;
            DoAddMM(beta,tempB,C);
          } else if (C.iscm()) {
            BandMatrix<Tb,ColMajor> tempB = B;
            C = alpha*A;
            DoAddMM(beta,tempB,C);
          } else {
            BandMatrix<Tb,DiagMajor> tempB = B;
            C = alpha*A;
            DoAddMM(beta,tempB,C);
          }
        } else {
          // Only A aliases C: consume A first, then add the independent B.
          C = alpha*A;
          DoAddMM(beta,B,C);
        }
      } else {
        // A does not alias C, so C may be overwritten with beta*B first.
        C = beta*B;
        DoAddMM(alpha,A,C);
      }
    }
    //cout<<"Done C => "<<C<<endl;
#ifdef XDEBUG
    if (Norm(C2-Matrix<T>(C)) > 
        0.001*(ABS(alpha)*Norm(A0)+ABS(beta)*Norm(B0))) {
      cerr<<"Band AddMM\n";
      cerr<<"alpha,beta = "<<alpha<<","<<beta<<endl;
      cerr<<"A = "<<TypeText(A)<<"  "<<A.cptr()<<"   "<<A0<<endl;
      cerr<<"B = "<<TypeText(B)<<"  "<<B.cptr()<<"   "<<B0<<endl;
      cerr<<"C = "<<TypeText(C)<<"  "<<C.cptr()<<" ->  "<<C<<endl;
      cerr<<"C2 = "<<C2<<endl;
      abort();
    }
#endif
  }

  template <class T, class Ta, class Tb> void AddMM(const T alpha,
      const GenBandMatrix<Ta>& A, const T beta, const GenMatrix<Tb>& B,
      const MatrixView<T>& C)
  {
#ifdef XDEBUG
    //cout<<"Band AddMM: alpha = "<<alpha<<", beta = "<<beta<<endl;
    //cout<<"A = "<<TypeText(A)<<"  "<<A.cptr()<<"  "<<A<<endl;
    //cout<<"B = "<<TypeText(B)<<"  "<<B.cptr()<<"  "<<B<<endl;
    //cout<<"C = "<<TypeText(C)<<"  "<<C.cptr()<<endl;
    Matrix<Ta> A0 = A;
    Matrix<Tb> B0 = B;
    Matrix<T> C2 = alpha*A0+beta*B0;
#endif
    TMVAssert(A.colsize() == B.colsize());
    TMVAssert(A.rowsize() == B.rowsize());
    TMVAssert(A.colsize() == C.colsize());
    TMVAssert(A.rowsize() == C.rowsize());

    if (C.isconj()) {
      AddMM(CONJ(alpha),A.Conjugate(),beta,B.Conjugate(),C.Conjugate());
    } else {
      if (C.colsize() > 0 && C.rowsize() > 0) {
        if (SameStorage(A,C)) {
          if (SameStorage(B,C)) {
            if (A.isrm()) {
              BandMatrix<Ta,RowMajor> tempA = A;
              C = beta*B;
              DoAddMM(alpha,tempA,BandMatrixView<T>(C,A.nlo(),A.nhi()));
            } else if (A.iscm()) {
              BandMatrix<Ta,ColMajor> tempA = A;
              C = beta*B;
              DoAddMM(alpha,tempA,BandMatrixView<T>(C,A.nlo(),A.nhi()));
            } else {
              BandMatrix<Ta,DiagMajor> tempA = A;
              C = beta*B;
              DoAddMM(alpha,tempA,BandMatrixView<T>(C,A.nlo(),A.nhi()));
            }
          } else {
            C = alpha*A;
            AddMM(beta,B,C);
          }
        } else {
          C = beta*B;
          DoAddMM(alpha,A,BandMatrixView<T>(C,A.nlo(),A.nhi()));
        }
      }
    }
#ifdef XDEBUG
    //cout<<"Done C => "<<C<<endl;
    if (Norm(C2-Matrix<T>(C)) > 
        0.001*(ABS(alpha)*Norm(A0)+ABS(beta)*Norm(B0))) {
      cerr<<"Band AddMM\n";
      cerr<<"alpha = "<<alpha<<","<<beta<<endl;
      cerr<<"A = "<<TypeText(A)<<"  "<<A.cptr()<<"   "<<A0<<endl;
      cerr<<"B = "<<TypeText(B)<<"  "<<B.cptr()<<"   "<<B0<<endl;
      cerr<<"C = "<<TypeText(C)<<"  "<<C.cptr()<<" ->  "<<C<<endl;
      cerr<<"C2 = "<<C2<<endl;
      abort();
    }
#endif
  }

#define InstFile "TMV_AddBB.inst"
#include "TMV_Inst.h"
#undef InstFile

} // namespace tmv


